{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71fe81dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "import IPython.display as display\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "\n",
    "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\"\n",
    "# 0 = all messages are logged (default behavior)\n",
    "# 1 = INFO messages are not printed\n",
    "# 2 = INFO and WARNING messages are not printed\n",
    "# 3 = INFO, WARNING, and ERROR messages are not printed\n",
    "\n",
    "# On Mac you may encounter an error related to OMP, this is a workaround, but slows down the code\n",
    "# https://github.com/dmlc/xgboost/issues/1715\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"True\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "132953ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
    "tf.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ac37fbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openbot import dataloader, data_augmentation, utils, train"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05c28641",
   "metadata": {},
   "source": [
    "## Set train and test dirs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc941e8e",
   "metadata": {},
   "source": [
    "Define the dataset directory and give it a name. Inside the dataset folder, there should be two folders, `train_data` and `test_data`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "823ef8f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_dir = \"dataset\"\n",
    "dataset_name = \"my_openbot\"\n",
    "train_data_dir = os.path.join(dataset_dir, \"train_data\")\n",
    "test_data_dir = os.path.join(dataset_dir, \"test_data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4388bbaa",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "625eb2bd",
   "metadata": {},
   "source": [
    "You may have to tune the learning rate and batch size depending on your available compute resources and dataset. As a general rule of thumb, if you increase the batch size by a factor of n, you can increase the learning rate by a factor of sqrt(n). In order to accelerate training and make it more smooth, you should increase the batch size as much as possible. In our paper we used a batch size of 128. For debugging and hyperparamter tuning, you can set the number of epochs to a small value like 10. If you want to train a model which will achieve good performance, you should set it to 50 or more. In our paper we used 100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a14c7cac",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = train.Hyperparameters()\n",
    "\n",
    "params.MODEL = \"pilot_net\"\n",
    "params.TRAIN_BATCH_SIZE = 16\n",
    "params.TEST_BATCH_SIZE = 16\n",
    "params.LEARNING_RATE = 0.0001\n",
    "params.NUM_EPOCHS = 10\n",
    "params.BATCH_NORM = True\n",
    "params.FLIP_AUG = False\n",
    "params.CMD_AUG = False\n",
    "params.USE_LAST = False\n",
    "params.WANDB = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82ac0929",
   "metadata": {},
   "source": [
    "## Pre-process the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b4494b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr = train.Training(params)\n",
    "tr.train_data_dir = train_data_dir\n",
    "tr.test_data_dir = test_data_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abfc9b9b",
   "metadata": {},
   "source": [
    "Running this for the first time will take some time. This code will match image frames to the controls (labels) and indicator signals (commands).  By default, data samples where the vehicle was stationary will be removed. If this is not desired, you need to set `tr.remove_zeros = False`. If you have made any changes to the sensor files, changed `remove_zeros` or moved your dataset to a new directory, you need to set `tr.redo_matching = True`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9d78141",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr.redo_matching = False\n",
    "tr.remove_zeros = True\n",
    "train.process_data(tr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c531aecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import threading\n",
    "\n",
    "\n",
    "def broadcast(event, payload=None):\n",
    "    print(event, payload)\n",
    "\n",
    "\n",
    "event = threading.Event()\n",
    "my_callback = train.MyCallback(broadcast, event)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fc670c9",
   "metadata": {},
   "source": [
    "In the next step, you can convert your dataset to a tfrecord, a data format optimized for training. You can skip this step if you already created a tfrecord before or if you want to train using the files directly. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59336936",
   "metadata": {},
   "outputs": [],
   "source": [
    "train.create_tfrecord(my_callback)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0a4a1cf",
   "metadata": {},
   "source": [
    "## Load the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a51ee383",
   "metadata": {},
   "source": [
    "If you did not create a tfrecord and want to load and buffer files from disk directly, set `no_tf_record = True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc2d2068",
   "metadata": {},
   "outputs": [],
   "source": [
    "no_tf_record = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "082c90bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "if no_tf_record:\n",
    "    tr.train_data_dir = train_data_dir\n",
    "    tr.test_data_dir = test_data_dir\n",
    "    train.load_data(tr, verbose=0)\n",
    "else:\n",
    "    tr.train_data_dir = os.path.join(dataset_dir, \"tfrecords/train.tfrec\")\n",
    "    tr.test_data_dir = os.path.join(dataset_dir, \"tfrecords/test.tfrec\")\n",
    "    train.load_tfrecord(tr, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edf34f3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "(image_batch, cmd_batch), label_batch = next(iter(tr.train_ds))\n",
    "utils.show_train_batch(image_batch.numpy(), cmd_batch.numpy(), label_batch.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a5d0f77",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "894cd54f",
   "metadata": {},
   "outputs": [],
   "source": [
    "train.do_training(tr, my_callback, verbose=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cfd4aac",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f623b65f",
   "metadata": {},
   "source": [
    "The loss and mean absolute error should decrease. This indicates that the model is fitting the data well. The custom metrics (direction and angle) should go towards 1. These provide some additional insight to the training progress. The direction metric measures weather or not predictions are in the same direction as the labels. Similarly the angle metric measures if the prediction is within a small angle of the labels. The intuition is that driving in the right direction with the correct steering angle is most critical part for good final performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efd65e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(tr.history.history[\"loss\"], label=\"loss\")\n",
    "plt.plot(tr.history.history[\"val_loss\"], label=\"val_loss\")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.legend(loc=\"lower right\")\n",
    "plt.savefig(os.path.join(tr.log_path, \"loss.png\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97e98f63",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(tr.history.history[\"mean_absolute_error\"], label=\"mean_absolute_error\")\n",
    "plt.plot(tr.history.history[\"val_mean_absolute_error\"], label=\"val_mean_absolute_error\")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Mean Absolute Error\")\n",
    "plt.legend(loc=\"lower right\")\n",
    "plt.savefig(os.path.join(tr.log_path, \"error.png\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa1752d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(tr.history.history[\"direction_metric\"], label=\"direction_metric\")\n",
    "plt.plot(tr.history.history[\"val_direction_metric\"], label=\"val_direction_metric\")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Direction Metric\")\n",
    "plt.legend(loc=\"lower right\")\n",
    "plt.savefig(os.path.join(tr.log_path, \"direction.png\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de04eec",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(tr.history.history[\"angle_metric\"], label=\"angle_metric\")\n",
    "plt.plot(tr.history.history[\"val_angle_metric\"], label=\"val_angle_metric\")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Angle Metric\")\n",
    "plt.legend(loc=\"lower right\")\n",
    "plt.savefig(os.path.join(tr.log_path, \"angle.png\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e72c1de3",
   "metadata": {},
   "source": [
    "Save tf lite models for best and last checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92a1ef62",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_index = np.argmax(\n",
    "    np.array(tr.history.history[\"val_angle_metric\"])\n",
    "    + np.array(tr.history.history[\"val_direction_metric\"])\n",
    ")\n",
    "best_checkpoint = str(\"cp-%04d.ckpt\" % (best_index + 1))\n",
    "best_tflite = utils.generate_tflite(tr.checkpoint_path, best_checkpoint)\n",
    "utils.save_tflite(best_tflite, tr.checkpoint_path, \"best\")\n",
    "print(\n",
    "    \"Best Checkpoint (val_angle: %s, val_direction: %s): %s\"\n",
    "    % (\n",
    "        tr.history.history[\"val_angle_metric\"][best_index],\n",
    "        tr.history.history[\"val_direction_metric\"][best_index],\n",
    "        best_checkpoint,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb605b90",
   "metadata": {},
   "outputs": [],
   "source": [
    "last_checkpoint = sorted(\n",
    "    [\n",
    "        d\n",
    "        for d in os.listdir(tr.checkpoint_path)\n",
    "        if os.path.isdir(os.path.join(tr.checkpoint_path, d))\n",
    "    ]\n",
    ")[-1]\n",
    "last_tflite = utils.generate_tflite(tr.checkpoint_path, last_checkpoint)\n",
    "utils.save_tflite(last_tflite, tr.checkpoint_path, \"last\")\n",
    "print(\n",
    "    \"Last Checkpoint (val_angle: %s, val_direction: %s): %s\"\n",
    "    % (\n",
    "        tr.history.history[\"val_angle_metric\"][-1],\n",
    "        tr.history.history[\"val_direction_metric\"][-1],\n",
    "        last_checkpoint,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0c57018",
   "metadata": {},
   "source": [
    "Evaluate the best model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a03c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_model = utils.load_model(\n",
    "    os.path.join(tr.checkpoint_path, best_checkpoint),\n",
    "    tr.loss_fn,\n",
    "    tr.metric_list,\n",
    "    tr.custom_objects,\n",
    ")\n",
    "test_loss, test_acc, test_dir, test_ang = best_model.evaluate(\n",
    "    tr.test_ds,\n",
    "    steps=tr.image_count_test / tr.hyperparameters.TEST_BATCH_SIZE,\n",
    "    verbose=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e01e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_SAMPLES = 15\n",
    "(image_batch, cmd_batch), label_batch = next(iter(tr.test_ds))\n",
    "pred_batch = best_model.predict(\n",
    "    (\n",
    "        tf.slice(image_batch, [0, 0, 0, 0], [NUM_SAMPLES, -1, -1, -1]),\n",
    "        tf.slice(cmd_batch, [0], [NUM_SAMPLES]),\n",
    "    )\n",
    ")\n",
    "utils.show_test_batch(\n",
    "    image_batch.numpy(), cmd_batch.numpy(), label_batch.numpy(), pred_batch\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daf9dbac",
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.compare_tf_tflite(best_model, best_tflite)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29c9d606",
   "metadata": {},
   "source": [
    "## Save the notebook as HTML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c47915f",
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.save_notebook()\n",
    "current_file = \"policy_learning.ipynb\"\n",
    "output_file = os.path.join(tr.log_path, \"notebook.html\")\n",
    "utils.output_HTML(current_file, output_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa902ebd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
